"""ResNet v1, v2, and segmentation models for Keras.



# Reference



- [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)

- [Identity Mappings in Deep Residual Networks](https://arxiv.org/abs/1603.05027)



Reference material for extended functionality:



- [ResNeXt](https://arxiv.org/abs/1611.05431) for Tiny ImageNet support.

- [Dilated Residual Networks](https://arxiv.org/pdf/1705.09914) for segmentation support

- [Deep Residual Learning for Instrument Segmentation in

   Robotic Surgery](https://arxiv.org/abs/1703.08580)

  for segmentation support.



Implementation Adapted from: github.com/raghakot/keras-resnet

"""  # pylint: disable=E501

from __future__ import division



import six

from keras.models import Model

from keras.layers import Input

from keras.layers import Activation

from keras.layers import Reshape

from keras.layers import Dense

from keras.layers import Conv2D

from keras.layers import MaxPooling2D

from keras.layers import GlobalMaxPooling2D

from keras.layers import GlobalAveragePooling2D

from keras.layers import Dropout

from keras.layers.merge import add

from keras.layers.normalization import BatchNormalization

from keras.regularizers import l2

from keras import backend as K

from keras_applications.imagenet_utils import _obtain_input_shape





def _bn_relu(x, bn_name=None, relu_name=None):

    """Helper to build a BN -> relu block

    """

    norm = BatchNormalization(axis=CHANNEL_AXIS, name=bn_name)(x)

    return Activation("relu", name=relu_name)(norm)





def _conv_bn_relu(**conv_params):

    """Helper to build a conv -> BN -> relu residual unit activation function.

       This is the original ResNet v1 scheme in https://arxiv.org/abs/1512.03385

    """

    filters = conv_params["filters"]

    kernel_size = conv_params["kernel_size"]

    strides = conv_params.setdefault("strides", (1, 1))

    dilation_rate = conv_params.setdefault("dilation_rate", (1, 1))

    conv_name = conv_params.setdefault("conv_name", None)

    bn_name = conv_params.setdefault("bn_name", None)

    relu_name = conv_params.setdefault("relu_name", None)

    kernel_initializer = conv_params.setdefault("kernel_initializer", "random_normal")

    padding = conv_params.setdefault("padding", "same")

    kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1.e-4))



    def f(x):

        x = Conv2D(filters=filters, kernel_size=kernel_size,

                   strides=strides, padding=padding,

                   dilation_rate=dilation_rate,

                   kernel_initializer=kernel_initializer,

                   kernel_regularizer=kernel_regularizer,

                   name=conv_name)(x)

        return _bn_relu(x, bn_name=bn_name, relu_name=relu_name)



    return f





def _bn_relu_conv(**conv_params):

    """Helper to build a BN -> relu -> conv residual unit with full pre-activation

    function. This is the ResNet v2 scheme proposed in

    http://arxiv.org/pdf/1603.05027v2.pdf

    """

    filters = conv_params["filters"]

    kernel_size = conv_params["kernel_size"]

    strides = conv_params.setdefault("strides", (1, 1))

    dilation_rate = conv_params.setdefault("dilation_rate", (1, 1))

    conv_name = conv_params.setdefault("conv_name", None)

    bn_name = conv_params.setdefault("bn_name", None)

    relu_name = conv_params.setdefault("relu_name", None)

    kernel_initializer = conv_params.setdefault("kernel_initializer", "random_normal")

    padding = conv_params.setdefault("padding", "same")

    kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1.e-4))



    def f(x):

        activation = _bn_relu(x, bn_name=bn_name, relu_name=relu_name)

        return Conv2D(filters=filters, kernel_size=kernel_size,

                      strides=strides, padding=padding,

                      dilation_rate=dilation_rate,

                      kernel_initializer=kernel_initializer,

                      kernel_regularizer=kernel_regularizer,

                      name=conv_name)(activation)



    return f





def _shortcut(input_feature, residual, conv_name_base=None, bn_name_base=None):

    """Adds a shortcut between input and residual block and merges them with "sum"

    """

    # Expand channels of shortcut to match residual.

    # Stride appropriately to match residual (width, height)

    # Should be int if network architecture is correctly configured.

    input_shape = K.int_shape(input_feature)

    residual_shape = K.int_shape(residual)

    stride_width = int(round(input_shape[ROW_AXIS] / residual_shape[ROW_AXIS]))

    stride_height = int(round(input_shape[COL_AXIS] / residual_shape[COL_AXIS]))

    equal_channels = input_shape[CHANNEL_AXIS] == residual_shape[CHANNEL_AXIS]



    shortcut = input_feature

    # 1 X 1 conv if shape is different. Else identity.

    if stride_width > 1 or stride_height > 1 or not equal_channels:

        print('reshaping via a convolution...')

        if conv_name_base is not None:

            conv_name_base = conv_name_base + '1'

        shortcut = Conv2D(filters=residual_shape[CHANNEL_AXIS],

                          kernel_size=(1, 1),

                          strides=(stride_width, stride_height),

                          padding="valid",

                          kernel_initializer="he_normal",

                          kernel_regularizer=l2(0.0001),

                          name=conv_name_base)(input_feature)

        if bn_name_base is not None:

            bn_name_base = bn_name_base + '1'

        shortcut = BatchNormalization(axis=CHANNEL_AXIS,

                                      name=bn_name_base)(shortcut)



    return add([shortcut, residual])





def _residual_block(block_function, filters, blocks, stage,

                    transition_strides=None, transition_dilation_rates=None,

                    dilation_rates=None, is_first_layer=False, dropout=None,

                    residual_unit=_bn_relu_conv):

    """Builds a residual block with repeating bottleneck blocks.



       stage: integer, current stage label, used for generating layer names

       blocks: number of blocks 'a','b'..., current block label, used for generating

            layer names

       transition_strides: a list of tuples for the strides of each transition

       transition_dilation_rates: a list of tuples for the dilation rate of each

            transition

    """

    if transition_dilation_rates is None:

        transition_dilation_rates = [(1, 1)] * blocks

    if transition_strides is None:

        transition_strides = [(1, 1)] * blocks

    if dilation_rates is None:

        dilation_rates = [1] * blocks



    def f(x):

        for i in range(blocks):

            is_first_block = is_first_layer and i == 0

            x = block_function(filters=filters, stage=stage, block=i,

                               transition_strides=transition_strides[i],

                               dilation_rate=dilation_rates[i],

                               is_first_block_of_first_layer=is_first_block,

                               dropout=dropout,

                               residual_unit=residual_unit)(x)

        return x



    return f





def _block_name_base(stage, block):

    """Get the convolution name base and batch normalization name base defined by

    stage and block.



    If there are less than 26 blocks they will be labeled 'a', 'b', 'c' to match the

    paper and keras and beyond 26 blocks they will simply be numbered.

    """

    if block < 27:

        block = '%c' % (block + 97)  # 97 is the ascii number for lowercase 'a'

    conv_name_base = 'res' + str(stage) + block + '_branch'

    bn_name_base = 'bn' + str(stage) + block + '_branch'

    return conv_name_base, bn_name_base





def basic_block(filters, stage, block, transition_strides=(1, 1),

                dilation_rate=(1, 1), is_first_block_of_first_layer=False, dropout=None,

                residual_unit=_bn_relu_conv):

    """Basic 3 X 3 convolution blocks for use on resnets with layers <= 34.

    Follows improved proposed scheme in http://arxiv.org/pdf/1603.05027v2.pdf

    """

    def f(input_features):

        conv_name_base, bn_name_base = _block_name_base(stage, block)

        if is_first_block_of_first_layer:

            # don't repeat bn->relu since we just did bn->relu->maxpool

            x = Conv2D(filters=filters, kernel_size=(3, 3),

                       strides=transition_strides,

                       dilation_rate=dilation_rate,

                       padding="same",

                       kernel_initializer="he_normal",

                       kernel_regularizer=l2(1e-4),

                       name=conv_name_base + '2a')(input_features)

        else:

            x = residual_unit(filters=filters, kernel_size=(3, 3),

                              strides=transition_strides,

                              dilation_rate=dilation_rate,

                              conv_name_base=conv_name_base + '2a',

                              bn_name_base=bn_name_base + '2a')(input_features)



        if dropout is not None:

            x = Dropout(dropout)(x)



        x = residual_unit(filters=filters, kernel_size=(3, 3),

                          conv_name_base=conv_name_base + '2b',

                          bn_name_base=bn_name_base + '2b')(x)



        return _shortcut(input_features, x)



    return f





def bottleneck(filters, stage, block, transition_strides=(1, 1),

               dilation_rate=(1, 1), is_first_block_of_first_layer=False, dropout=None,

               residual_unit=_bn_relu_conv):

    """Bottleneck architecture for > 34 layer resnet.

    Follows improved proposed scheme in http://arxiv.org/pdf/1603.05027v2.pdf



    Returns:

        A final conv layer of filters * 4

    """

    def f(input_feature):

        conv_name_base, bn_name_base = _block_name_base(stage, block)

        if is_first_block_of_first_layer:

            # don't repeat bn->relu since we just did bn->relu->maxpool

            x = Conv2D(filters=filters, kernel_size=(1, 1),

                       strides=transition_strides,

                       dilation_rate=dilation_rate,

                       padding="same",

                       kernel_initializer="he_normal",

                       kernel_regularizer=l2(1e-4),

                       name=conv_name_base + '2a')(input_feature)

        else:

            x = residual_unit(filters=filters, kernel_size=(1, 1),

                              strides=transition_strides,

                              dilation_rate=dilation_rate,

                              conv_name_base=conv_name_base + '2a',

                              bn_name_base=bn_name_base + '2a')(input_feature)



        if dropout is not None:

            x = Dropout(dropout)(x)



        x = residual_unit(filters=filters, kernel_size=(3, 3),

                          conv_name_base=conv_name_base + '2b',

                          bn_name_base=bn_name_base + '2b')(x)



        if dropout is not None:

            x = Dropout(dropout)(x)



        x = residual_unit(filters=filters * 4, kernel_size=(1, 1),

                          conv_name_base=conv_name_base + '2c',

                          bn_name_base=bn_name_base + '2c')(x)



        return _shortcut(input_feature, x)



    return f





def _handle_dim_ordering():

    global ROW_AXIS

    global COL_AXIS

    global CHANNEL_AXIS

    if K.image_data_format() == 'channels_last':

        ROW_AXIS = 1

        COL_AXIS = 2

        CHANNEL_AXIS = 3

    else:

        CHANNEL_AXIS = 1

        ROW_AXIS = 2

        COL_AXIS = 3





def _string_to_function(identifier):

    if isinstance(identifier, six.string_types):

        res = globals().get(identifier)

        if not res:

            raise ValueError('Invalid {}'.format(identifier))

        return res

    return identifier





def ResNet(input_shape=None, classes=10, block='bottleneck', residual_unit='v2',

           repetitions=None, initial_filters=64, activation='softmax', include_top=True,

           input_tensor=None, dropout=None, transition_dilation_rate=(1, 1),

           initial_strides=(2, 2), initial_kernel_size=(7, 7), initial_pooling='max',

           final_pooling=None, top='classification'):

    """Builds a custom ResNet like architecture. Defaults to ResNet50 v2.



    Args:

        input_shape: optional shape tuple, only to be specified

            if `include_top` is False (otherwise the input shape

            has to be `(224, 224, 3)` (with `channels_last` dim ordering)

            or `(3, 224, 224)` (with `channels_first` dim ordering).

            It should have exactly 3 dimensions,

            and width and height should be no smaller than 8.

            E.g. `(224, 224, 3)` would be one valid value.

        classes: The number of outputs at final softmax layer

        block: The block function to use. This is either `'basic'` or `'bottleneck'`.

            The original paper used `basic` for layers < 50.

        repetitions: Number of repetitions of various block units.

            At each block unit, the number of filters are doubled and the input size

            is halved. Default of None implies the ResNet50v2 values of [3, 4, 6, 3].

        residual_unit: the basic residual unit, 'v1' for conv bn relu, 'v2' for bn relu

            conv. See [Identity Mappings in

            Deep Residual Networks](https://arxiv.org/abs/1603.05027)

            for details.

        dropout: None for no dropout, otherwise rate of dropout from 0 to 1.

            Based on [Wide Residual Networks.(https://arxiv.org/pdf/1605.07146) paper.

        transition_dilation_rate: Dilation rate for transition layers. For semantic

            segmentation of images use a dilation rate of (2, 2).

        initial_strides: Stride of the very first residual unit and MaxPooling2D call,

            with default (2, 2), set to (1, 1) for small images like cifar.

        initial_kernel_size: kernel size of the very first convolution, (7, 7) for

            imagenet and (3, 3) for small image datasets like tiny imagenet and cifar.

            See [ResNeXt](https://arxiv.org/abs/1611.05431) paper for details.

        initial_pooling: Determine if there will be an initial pooling layer,

            'max' for imagenet and None for small image datasets.

            See [ResNeXt](https://arxiv.org/abs/1611.05431) paper for details.

        final_pooling: Optional pooling mode for feature extraction at the final

            model layer when `include_top` is `False`.

            - `None` means that the output of the model

                will be the 4D tensor output of the

                last convolutional layer.

            - `avg` means that global average pooling

                will be applied to the output of the

                last convolutional layer, and thus

                the output of the model will be a

                2D tensor.

            - `max` means that global max pooling will

                be applied.

        top: Defines final layers to evaluate based on a specific problem type. Options

            are 'classification' for ImageNet style problems, 'segmentation' for

            problems like the Pascal VOC dataset, and None to exclude these layers

            entirely.



    Returns:

        The keras `Model`.

    """

    if activation not in ['softmax', 'sigmoid', None]:

        raise ValueError('activation must be one of "softmax", "sigmoid", or None')

    if activation == 'sigmoid' and classes != 1:

        raise ValueError('sigmoid activation can only be used when classes = 1')

    if repetitions is None:

        repetitions = [3, 4, 6, 3]

    # Determine proper input shape

    input_shape = _obtain_input_shape(input_shape,

                                      default_size=32,

                                      min_size=8,

                                      data_format=K.image_data_format(),

                                      require_flatten=include_top)

    _handle_dim_ordering()

    if len(input_shape) != 3:

        raise Exception("Input shape should be a tuple (nb_channels, nb_rows, nb_cols)")



    if block == 'basic':

        block_fn = basic_block

    elif block == 'bottleneck':

        block_fn = bottleneck

    elif isinstance(block, six.string_types):

        block_fn = _string_to_function(block)

    else:

        block_fn = block



    if residual_unit == 'v2':

        residual_unit = _bn_relu_conv

    elif residual_unit == 'v1':

        residual_unit = _conv_bn_relu

    elif isinstance(residual_unit, six.string_types):

        residual_unit = _string_to_function(residual_unit)

    else:

        residual_unit = residual_unit



    # Permute dimension order if necessary

    if K.image_data_format() == 'channels_first':

        input_shape = (input_shape[1], input_shape[2], input_shape[0])

    # Determine proper input shape

    input_shape = _obtain_input_shape(input_shape,

                                      default_size=32,

                                      min_size=8,

                                      data_format=K.image_data_format(),

                                      require_flatten=include_top)



    img_input = Input(shape=input_shape, tensor=input_tensor)

    x = _conv_bn_relu(filters=initial_filters, kernel_size=initial_kernel_size,

                      strides=initial_strides)(img_input)

    if initial_pooling == 'max':

        x = MaxPooling2D(pool_size=(3, 3), strides=initial_strides, padding="same")(x)



    block = x

    filters = initial_filters

    for i, r in enumerate(repetitions):

        transition_dilation_rates = [transition_dilation_rate] * r

        transition_strides = [(1, 1)] * r

        if transition_dilation_rate == (1, 1):

            transition_strides[0] = (2, 2)

        block = _residual_block(block_fn, filters=filters,

                                stage=i, blocks=r,

                                is_first_layer=(i == 0),

                                dropout=dropout,

                                transition_dilation_rates=transition_dilation_rates,

                                transition_strides=transition_strides,

                                residual_unit=residual_unit)(block)

        filters *= 2



    # Last activation

    inter = _bn_relu(block)



    # Classifier block

    if include_top and top is 'classification':

        x = GlobalAveragePooling2D()(inter)

        x = Dense(units=classes, activation=activation,

                  kernel_initializer="he_normal")(x)

    elif include_top and top is 'segmentation':

        x = Conv2D(classes, (1, 1), activation='linear', padding='same')(inter)



        if K.image_data_format() == 'channels_first':

            channel, row, col = input_shape

        else:

            row, col, channel = input_shape



        x = Reshape((row * col, classes))(x)

        x = Activation(activation)(x)

        x = Reshape((row, col, classes))(x)

    elif final_pooling == 'avg':

        x = GlobalAveragePooling2D()(inter)

    elif final_pooling == 'max':

        x = GlobalMaxPooling2D()(inter)



    model = Model(inputs=img_input, outputs=x)

    return model





def ResNet18(input_shape, classes):

    """ResNet with 18 layers and v2 residual units

    """

    return ResNet(input_shape, classes, basic_block, repetitions=[2, 2, 2, 2])





def ResNet34(input_shape, classes):

    """ResNet with 34 layers and v2 residual units

    """

    return ResNet(input_shape, classes, basic_block, repetitions=[3, 4, 6, 3])





def ResNet50(input_shape, classes):

    """ResNet with 50 layers and v2 residual units

    """

    return ResNet(input_shape, classes, bottleneck, repetitions=[3, 4, 6, 3])





def ResNet101(input_shape, classes):

    """ResNet with 101 layers and v2 residual units

    """

    return ResNet(input_shape, classes, bottleneck, repetitions=[3, 4, 23, 3])





def ResNet152(input_shape, classes):

    """ResNet with 152 layers and v2 residual units

    """

    return ResNet(input_shape, classes, bottleneck, repetitions=[3, 8, 36, 3])